{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c20da46d-004a-4fd7-bb4d-66112bd82f16",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f7dca27541af4ef58dc615e03fab8141",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/19 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import torch\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig\n",
    "\n",
    "base_model_id = \"mistralai/Mixtral-8x7B-v0.1\"\n",
    "bnb_config = BitsAndBytesConfig(\n",
    "    load_in_4bit=True,\n",
    "    bnb_4bit_use_double_quant=True,\n",
    "    bnb_4bit_quant_type=\"nf4\",\n",
    "    bnb_4bit_compute_dtype=torch.bfloat16\n",
    ")\n",
    "\n",
    "base_model = AutoModelForCausalLM.from_pretrained(\n",
    "    base_model_id,  # Mistral, same as before\n",
    "    quantization_config=bnb_config,  # Same quantization config as before\n",
    "    device_map = \"auto\",\n",
    "    trust_remote_code=True,\n",
    ")\n",
    "\n",
    "pad_token_id = 0\n",
    "eval_tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True)\n",
    "eval_tokenizer.pad_token_id = pad_token_id\n",
    "eval_tokenizer.pad_token = eval_tokenizer.convert_ids_to_tokens(pad_token_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "f64cce48-4219-4295-8762-0cbbd0027ade",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "你是一个唐诗助手,帮助用户写一首对应要求的唐诗\n",
      "\n",
      "INPUT:\n",
      "作者:李白\n",
      "标签:五言绝句;宫怨\n",
      "\n",
      "OUTPUT:\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "eval_prompt = \"\"\"<s>你是一个唐诗助手,帮助用户写一首对应要求的唐诗\n",
    "\n",
    "INPUT:\n",
    "作者:李白\n",
    "标签:五言绝句;宫怨\n",
    "\n",
    "OUTPUT:\"\"\"\n",
    "model_input = eval_tokenizer(eval_prompt, return_tensors=\"pt\").to(\"cuda\")\n",
    "\n",
    "base_model.eval()\n",
    "with torch.no_grad():\n",
    "    print(eval_tokenizer.decode(base_model.generate(**model_input, max_new_tokens=500)[0], skip_special_tokens=True))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bf1df2c1-7add-41e6-ba71-f3cc0601ec26",
   "metadata": {},
   "source": [
    "# test trained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4daa631c-07b5-4fc3-b8c8-bf59db6bb144",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_model_id = \"HenryJJ/tangshi-mixtral\"\n",
    "bnb_config = BitsAndBytesConfig(\n",
    "    load_in_4bit=True,\n",
    "    bnb_4bit_use_double_quant=True,\n",
    "    bnb_4bit_quant_type=\"nf4\",\n",
    "    bnb_4bit_compute_dtype=torch.bfloat16\n",
    ")\n",
    "\n",
    "base_model = AutoModelForCausalLM.from_pretrained(\n",
    "    base_model_id,  # Mistral, same as before\n",
    "    quantization_config=bnb_config,  # Same quantization config as before\n",
    "    device_map = \"auto\",\n",
    "    trust_remote_code=True,\n",
    ")\n",
    "\n",
    "eval_tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c574d098-ca8b-4854-8c42-f902b35453ce",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "base_model.config.use_cache"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "2c09d0e8-d592-4492-a523-d4f7bdcf1d1c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "你是一个唐诗助手,帮助用户写一首对应要求的唐诗\n",
      "\n",
      "INPUT:\n",
      "作者:李白\n",
      "标签:五言绝句;宫怨\n",
      "\n",
      "OUTPUT:\n",
      "宮詞\n",
      "玉階生白露，夜久侵羅襪。\n",
      "却下水精簾，玲瓏望秋月。\n"
     ]
    }
   ],
   "source": [
    "eval_prompt = \"\"\"<s>你是一个唐诗助手,帮助用户写一首对应要求的唐诗\n",
    "\n",
    "INPUT:\n",
    "作者:李白\n",
    "标签:五言绝句;宫怨\n",
    "\n",
    "OUTPUT:\n",
    "\"\"\"\n",
    "model_input = eval_tokenizer(eval_prompt, return_tensors=\"pt\").to(\"cuda\")\n",
    "\n",
    "base_model.eval()\n",
    "with torch.no_grad():\n",
    "    print(eval_tokenizer.decode(base_model.generate(**model_input, max_new_tokens=500)[0], skip_special_tokens=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00f22b66-49dd-438f-8826-f3240122f83c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
